#include "unification.h"


namespace lp {

	bool unify(const Functor* f1, const Functor* f2)
	{
		// No need to unroll changes since subs are temporary, call shallow_unify directly
		Subst s;
		set<id_type> b;
		return shallow_unify(f1,f2,s,b);
	}

	// Note: subs will be hard copies, but NOT expanded
	bool unify(const Functor* f1, const Functor* f2, Subst& subs)
	{
		set<id_type> b;
		return unify(f1,f2,subs,b);
	}

	bool unify(const Functor* f1, const Functor* f2, Subst& subs, set<id_type>& bindings)
	{
		if (shallow_unify(f1,f2,subs,bindings)) {
			// Create hard copies but do NOT expand
			subs.deepen();
			return true;
		} else {
			// undo all bindings
			for_each(bindings.begin(),bindings.end(),[&](id_type s){ subs.erase(s); });
			return false;
		}
	}

	// Unify the variable var_id with the term f, extending subs.
	// Commit/rollback works like the four-argument unify():
	//   - on success, subs.deepen() creates hard copies;
	//   - on failure, every id in bindings is erased from subs.
	bool unify_var(id_type var_id, const Functor* f, Subst& subs, set<id_type>& bindings)
	{
		const bool unified = shallow_unify_var(var_id,f,subs,bindings);
		if (unified) {
			subs.deepen();
		} else {
			for (id_type undo_id : bindings) subs.erase(undo_id);
		}
		return unified;
	}

	// Unify f1 and f2 by structural recursion, adding soft bindings to subs via
	// shallow_unify_var. Each newly bound variable id is inserted into bindings.
	// This does NOT deepen subs and does NOT undo partial bindings on failure;
	// the unify() overloads take care of that.
	// Null handling (same convention as shallow_match):
	//   - two null terms unify;
	//   - a null term never unifies with a non-null one.
	bool shallow_unify(
		const Functor* f1,
		const Functor* f2,
		Subst& subs,
		set<id_type>& bindings)
	{
		// Base cases: null pointers. Checked before any dereference below, including
		// for argument subterms reached through the recursion.
		if (!f1) return f2 == nullptr;
		if (!f2) return false;

		// A variable on either side is handled by variable binding
		if (f1->is_variable()) return shallow_unify_var(f1->id(),f2,subs,bindings);
		if (f2->is_variable()) return shallow_unify_var(f2->id(),f1,subs,bindings);

		// Non-variable roots must have identical id and arity
		if (f1->id() != f2->id() || f1->arity() != f2->arity()) return false;

		// Unify the arguments pairwise, stopping at the first failure
		for (Functor::arity_type i = 0; i < f1->arity(); ++i) {
			if (!shallow_unify(f1->arg(i),f2->arg(i),subs,bindings)) return false;
		}

		return true; // everything was successfully unified
	}

	// Unify the variable var_id with the term f2, using soft bindings only.
	//   - If var_id is already bound, its binding is unified with f2 instead.
	//   - If f2 is the same variable, the call succeeds without binding anything.
	//   - If f2 is a bound variable, var_id is unified with f2's binding.
	//   - Otherwise var_id is bound to f2 (and recorded in bindings), unless
	//     occurs_check finds var_id inside f2 (following bindings in subs);
	//     in that case unification fails.
	bool shallow_unify_var(
		id_type var_id,
		const Functor* f2,
		Subst& subs,
		set<id_type>& bindings)
	{
		// Already bound: unify the existing binding with f2 instead
		if (const Functor* current = subs.get(var_id)) {
			return shallow_unify(current,f2,subs,bindings);
		}

		if (f2->is_variable()) {
			// X = X is trivially satisfied
			if (var_id == f2->id()) return true;
			// Follow f2's binding when it has one
			if (const Functor* target = subs.get(f2->id())) {
				return shallow_unify_var(var_id,target,subs,bindings);
			}
		}

		// Refuse cyclic bindings such as X = f(X)
		if (occurs_check(var_id,f2,subs)) return false;

		subs.soft_add(var_id,f2);
		bindings.insert(var_id);
		return true;
	}


	bool match(const Functor* f1, const Functor* f2)
	{
		Subst s;
		set<id_type> b;
		return shallow_match(f1,f2,s,b);
	}

	bool match(const Functor* f1, const Functor* f2, Subst& subs)
	{
		set<id_type> b;
		return match(f1,f2,subs,b);
	}

	bool match(const Functor* f1, const Functor* f2, Subst& subs, set<id_type>& bindings)
	{
		if (shallow_match(f1,f2,subs,bindings)) {
			subs.create_match(); // Note: this only creates hard copies, no expansion needed when matching
			return true;
		} else {
			// undo all bindings
			for_each(bindings.begin(),bindings.end(),[&](id_type s){ subs.erase(s); });
			return false;
		}
	}

	bool shallow_match(const Functor* f1, const Functor* f2, Subst& subs, set<id_type>& bindings)
	{
		// Base cases: null pointers
		if (!f1) return f2 == nullptr;
		if (!f2) return false;

		// Variable bindings
		if (f1->is_variable()) return shallow_match_var(f1->id(),f2,subs,bindings);

		// Root nodes must be identical
		if (f1->id() != f2->id() || f1->arity() != f2->arity()) return false;

		// Compound expression
		for (Functor::arity_type i = 0; i < f1->arity(); ++i) {
			if (!shallow_match(f1->arg(i),f2->arg(i),subs,bindings)) return false;
		}

		return true; // everything was successfully matched
	}

	// Match the pattern variable var_id against f2.
	//   - If var_id is already bound, succeed only if its binding compares equal (==) to f2.
	//   - Otherwise soft-bind var_id to f2 and record it in bindings.
	// No occurs check is performed here.
	bool shallow_match_var(id_type var_id, const Functor* f2, Subst& subs, set<id_type>& bindings)
	{
		const Functor* existing = subs.get(var_id);
		if (!existing) {
			subs.soft_add(var_id,f2);
			bindings.insert(var_id);
			return true;
		}
		return *existing == *f2;
	}

	// Return true if variable var occurs in expr. This covers:
	//   - a direct occurrence among the nodes visited by expr's depth iterator;
	//   - an indirect occurrence, found by recursively checking the binding in
	//     subs of any variable appearing in expr.
	bool occurs_check(id_type var, const Functor* expr, const Subst& subs)
	{
		for (auto node = expr->depth_begin(); node != expr->depth_end(); ++node) {
			if (!node->is_variable()) continue;

			// Direct occurrence
			if (node->id() == var) return true;

			// Indirect occurrence through this variable's current binding
			const Functor* bound_to = subs.get(node->id());
			if (bound_to && occurs_check(var,bound_to,subs)) return true;
		}
		return false;
	}


}

